import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

from pytorch_lightning.core import LightningModule
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import random
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import TensorDataset, random_split
from tqdm import tqdm
import pytorch_lightning as pl
import torch.nn as nn
import json
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from transformers import AdamW,BertTokenizer,PYTORCH_PRETRAINED_BERT_CACHE,RobertaTokenizer
from collections import OrderedDict
from sklearn.metrics import average_precision_score,precision_recall_curve,label_ranking_loss

from cmlm_model import CMLModel
from nce_model import NCEModel
from data_reader import *
from utils import *
import jsonlines

import time
import pickle

class Net(LightningModule):
    """Lightning module wrapping a CMLM or NCE answer-scoring model.

    The concrete model class is selected from ``hparams.ktltype``
    (``"cmlm"`` -> ``CMLModel``, ``"nce"`` -> ``NCEModel``) and built with
    ``from_pretrained``. Data loaders are supplied by the caller (through the
    constructor or :meth:`set_loaders`) rather than created here.
    """

    def __init__(self, hparams=None, train_loader=None, eval_loader=None):
        """Build the module.

        :param hparams: argparse namespace; this method reads
            ``load_checkpoint``, ``bert_model`` and (optionally) ``ktltype``.
            When ``None``, an empty shell without a wrapped model is created.
        :param train_loader: loader returned by :meth:`train_dataloader`.
        :param eval_loader: loader returned by the val/test dataloader hooks.
        :raises KeyError: if ``ktltype`` is not a supported model type.
        """
        super().__init__()
        self.hparams = hparams

        if hparams is not None:
            self.opts = vars(hparams)
            self.train_loader = train_loader
            self.eval_loader = eval_loader

            # Prefer an explicit checkpoint path over the base model name.
            if hparams.load_checkpoint is not None:
                bert_model = hparams.load_checkpoint
            else:
                bert_model = hparams.bert_model

            type_map = {
                "nce": NCEModel,
                "cmlm": CMLModel,
            }
            ktltype = self.opts.get("ktltype", "trip")
            if ktltype not in type_map:
                # Same exception type as a bare dict lookup, but with a
                # message naming the valid choices ("trip" is not one).
                raise KeyError(
                    f"Unsupported ktltype {ktltype!r}; expected one of {sorted(type_map)}"
                )
            self.model = type_map[ktltype].from_pretrained(
                bert_model,
                args=hparams,
                cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                       'distributed_{}'.format(-1)),
            )

    def set_loaders(self, train_loader=None, eval_loader=None):
        """Replace the loaders served by the dataloader hooks."""
        self.train_loader = train_loader
        self.eval_loader = eval_loader

    def load_checkpoint(self, hparams):
        """Load ``checkpoint["state_dict"]`` from ``hparams.load_checkpoint``.

        Tensors are mapped to CPU first; ``load_state_dict`` then copies them
        into the module's existing parameters. Loading to CPU avoids
        requiring the device the checkpoint was saved from.
        """
        checkpoint = torch.load(hparams.load_checkpoint, map_location="cpu")
        print(f"Keys in checkpoint:{checkpoint.keys()}", flush=True)
        self.load_state_dict(checkpoint["state_dict"])

    def prepare_data(self):
        """No-op: datasets are built outside the module and injected."""
        return

    @pl.data_loader
    def train_dataloader(self):
        return self.train_loader

    @pl.data_loader
    def val_dataloader(self):
        return self.eval_loader

    @pl.data_loader
    def test_dataloader(self):
        # Testing reuses the evaluation loader.
        return self.eval_loader

    def configure_optimizers(self):
        """Build AdamW with decoupled weight decay and a linear warmup/decay schedule.

        Biases and LayerNorm weights are excluded from weight decay.

        NOTE(review): the schedule length is ``len(train_loader) * epochs``
        (batches), but a bare scheduler in the returned list may be stepped
        per epoch depending on the Lightning version, and ``main`` uses
        ``accumulate_grad_batches=4`` -- confirm the intended step interval.
        """
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.opts['weight_decay']},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.opts['learning_rate'], eps=1e-09)
        scheduler = self.get_linear_schedule_with_warmup(
            optimizer, self.opts['warm_up_steps'], len(self.train_loader) * self.opts['epochs'])
        return [optimizer], [scheduler]

    def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        """Return a LambdaLR that ramps 0 -> 1 over warmup, then decays linearly to 0."""
        def lr_lambda(current_step):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(
                0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
            )

        return LambdaLR(optimizer, lr_lambda, last_epoch)

    def forward(self, batch):
        """Unpack ``batch`` positionally into the wrapped model's forward."""
        return self.model.forward(*batch)

    def mlm_forward(self, batch):
        """Compute the CMLM training loss for one batch.

        ``batch`` is unpacked into ``(ctx_inputs, question_inputs, ans_inputs,
        labels, label_mask)``. Each input tensor is flattened to 2-D over its
        last dimension and scored separately with the same labels and mask.
        The three per-example losses (``reduce=False``) are reshaped to
        ``(-1, option_nums)``, summed, and averaged into a scalar.
        """
        ctx_inputs, question_inputs, ans_inputs, labels, label_mask = batch
        # presumably options sit on dim 1 for batched (3-D) input -- TODO confirm
        option_nums = ans_inputs.shape[1] if len(ans_inputs.shape) > 2 else ans_inputs.shape[0]
        flat_ctx_inputs = ctx_inputs.view(-1, ctx_inputs.size(-1))
        flat_q_inputs = question_inputs.view(-1, question_inputs.size(-1))
        flat_ans_inputs = ans_inputs.view(-1, ans_inputs.size(-1))
        flat_labels = labels.view(-1, labels.size(-1))
        flat_att_mask = label_mask.view(-1, label_mask.size(-1))
        ctx_loss = self.model.forward(flat_ctx_inputs, attention_mask=flat_att_mask,
                                      masked_lm_labels=flat_labels, reduce=False)[0]
        ques_loss = self.model.forward(flat_q_inputs, attention_mask=flat_att_mask,
                                       masked_lm_labels=flat_labels, reduce=False)[0]
        ans_loss = self.model.forward(flat_ans_inputs, attention_mask=flat_att_mask,
                                      masked_lm_labels=flat_labels, reduce=False)[0]
        ctx_loss = ctx_loss.view(-1, option_nums)
        ques_loss = ques_loss.view(-1, option_nums)
        ans_loss = ans_loss.view(-1, option_nums)
        dist_scores = ctx_loss + ques_loss + ans_loss
        return torch.mean(dist_scores)

    def nce_forward(self, batch):
        """NCE loss is whatever the wrapped model's forward returns for ``batch``."""
        return self.forward(batch)

    def training_step(self, batch, batch_idx):
        """Dispatch to the loss for ``hparams.ktltype`` and report it.

        :raises ValueError: for a ``ktltype`` other than ``"cmlm"``/``"nce"``
            (previously surfaced as an ``UnboundLocalError``).
        """
        if self.hparams.ktltype == "cmlm":
            loss = self.mlm_forward(batch)
        elif self.hparams.ktltype == "nce":
            loss = self.nce_forward(batch)
        else:
            raise ValueError(f"Unsupported ktltype {self.hparams.ktltype!r}")

        tqdm_dict = {'train_loss': loss}

        # can also return just a scalar instead of a dict (return loss_val)
        return OrderedDict({
            'loss': loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict,
        })

    def validation_step(self, batch, batch_idx):
        """Score one validation batch and compute accuracy.

        The prediction is the option with the LOWEST score (argmin over
        dim 1). The targets are the last element of ``batch``. Results are
        returned as tensors on the same device as the scores.

        :raises ValueError: for a ``ktltype`` other than ``"cmlm"``/``"nce"``.
        """
        if self.hparams.ktltype == "cmlm":
            score = self.model.score(batch[0], batch[1], batch[2], batch[3], batch[4])
        elif self.hparams.ktltype == "nce":
            score = self.model.score(batch[0], batch[1], batch[2], batch[3], batch[4], batch[5])
        else:
            raise ValueError(f"Unsupported ktltype {self.hparams.ktltype!r}")

        logits = score
        device = logits.device  # keep outputs where the model ran (CPU or GPU)

        preds = np.argmin(logits.detach().cpu().tolist(), axis=1)
        targets = batch[-1].detach().cpu().tolist()
        accuracy = simple_accuracy(preds, targets)

        return {
            'acc': torch.tensor(accuracy, device=device),
            'preds': torch.tensor(preds, device=device),
            'targets': torch.tensor(targets, device=device),
        }

    def validation_end(self, outputs):
        """Average per-batch accuracy, append it to ``val_metrics.jsonl``, and log it.

        With no outputs there is nothing to average, so no metric is written
        (previously this raised ``ZeroDivisionError``).
        """
        if not outputs:
            return {'progress_bar': {}, 'log': {}}

        val_acc_mean = 0
        for output in outputs:
            val_acc = output['acc']

            # reduce manually when using dp
            if self.trainer.use_dp or self.trainer.use_ddp2:
                val_acc = torch.mean(val_acc)

            val_acc_mean += val_acc

        val_acc_mean /= len(outputs)

        tqdm_dict = {'acc': val_acc_mean}
        output_dict = {'acc': val_acc_mean.cpu().tolist()}

        with jsonlines.open(os.path.join(self.hparams.output_dir, "val_metrics.jsonl"), 'a') as f:
            f.write(output_dict)

        return {'progress_bar': tqdm_dict, 'log': tqdm_dict}

    def test_step(self, batch, batch_idx):
        """Predict the lowest-scoring option for each example in ``batch``.

        NOTE(review): unlike ``validation_step``, this passes the WHOLE batch
        to ``score`` and takes element ``[0]`` of the result before argmin
        over dim 1; confirm this matches ``score``'s signature/return.
        """
        outputs = self.model.score(*batch)
        logits = outputs[0]
        preds = torch.argmin(logits, axis=1)
        return {'preds': preds}

    def test_end(self, outputs):
        """Collect predictions from all batches and write them one per line.

        The predictions are written to ``<output_dir>/test_predictions.txt``
        and the flat list of predicted indices is returned.
        """
        preds = []
        for output in outputs:
            # Flatten across batches. The old code called .tolist() on a list
            # for the 2nd batch and wrote ints to a text file; both raised.
            preds.extend(output["preds"].tolist())
        save_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        with open(save_file, "w") as ofd:
            for pred in preds:
                ofd.write(f"{pred}\n")
        return preds

    @staticmethod
    def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
        """Hook for module-specific CLI arguments; none are added currently."""
        # Available as self.hparams
        return parent_parser
    
    
def main(hparams, prefix='full_'):
    """Main training / evaluation routine for this project.

    With ``hparams.evaluate`` set, the checkpoint at
    ``hparams.load_checkpoint`` is loaded and tested on the dev set.
    Otherwise a model is trained (from scratch or from a checkpoint) with
    accuracy-based checkpointing and early stopping.

    :param hparams: parsed argparse namespace of run options.
    :param prefix: filename prefix for the saved best checkpoint.
    """
    opts = vars(hparams)

    if hparams.seed is not None:
        random.seed(hparams.seed)
        # Seed NumPy's global RNG too, in case data helpers draw from it.
        np.random.seed(hparams.seed)
        torch.manual_seed(hparams.seed)
        cudnn.deterministic = True
        # Autotuning may pick non-deterministic kernels; keep it off when a
        # deterministic run is requested.
        cudnn.benchmark = False

    if "roberta" not in hparams.bert_model:
        tokenizer = BertTokenizer.from_pretrained(hparams.bert_model)
    else:
        tokenizer = RobertaTokenizer.from_pretrained(hparams.bert_model)

    if hparams.evaluate:
        print('Loading from:', hparams.load_checkpoint)
        eval_dataset = create_valdataset(hparams, "dev",
                                         ValDataset(hparams.val_file, tokenizer, hparams=hparams, typet="dev"))
        eval_loader = DataLoader(eval_dataset, batch_size=opts['predict_batch_size'], shuffle=False)
        pretrained_model = Net.load_from_checkpoint(hparams.load_checkpoint)
        pretrained_model.cuda()
        # The eval loader doubles as the train loader; only testing runs here.
        pretrained_model.set_loaders(eval_loader, eval_loader)
        trainer = pl.Trainer(
            gpus=hparams.gpus,
            distributed_backend=hparams.distributed_backend,
            use_amp=hparams.fp16)
        trainer.test(pretrained_model)
    else:
        train_dataset = create_tensordataset(hparams, "train",
                                             KTLDataset(file_path=hparams.train_file, tokenizer=tokenizer,
                                                        hparams=hparams))
        eval_dataset = create_valdataset(hparams, "dev",
                                         ValDataset(hparams.val_file, tokenizer, hparams=hparams))
        train_loader = DataLoader(train_dataset, batch_size=opts['train_batch_size'], shuffle=True, drop_last=True)
        eval_loader = DataLoader(eval_dataset, batch_size=opts['predict_batch_size'], shuffle=False)

        if hparams.load_checkpoint is None:
            model = Net(hparams, train_loader, eval_loader)
        else:
            print(f"Loading from Checkpoint: {hparams.load_checkpoint}", flush=True)
            model = Net.load_from_checkpoint(hparams.load_checkpoint)
            model.set_loaders(train_loader, eval_loader)
            # NOTE(review): load_from_checkpoint presumably already restores
            # the weights; this re-reads the same state_dict explicitly.
            model.load_checkpoint(hparams)
            model.cuda()

        # Keep only the single best checkpoint by validation accuracy.
        checkpoint_callback = ModelCheckpoint(
            filepath=hparams.output_dir,
            save_top_k=1,
            verbose=True,
            monitor='acc',
            mode='max',
            prefix=prefix
        )

        early_stopping = EarlyStopping(
            monitor='acc',
            mode='max',
            verbose=True,
            patience=50
        )

        # Resume full trainer state only when explicitly requested.
        resume = hparams.load_checkpoint if hparams.resume else None

        trainer = pl.Trainer(
            default_save_path=hparams.output_dir,
            gpus=hparams.gpus,
            max_epochs=hparams.epochs,
            distributed_backend=hparams.distributed_backend,
            use_amp=hparams.fp16,
            checkpoint_callback=checkpoint_callback,
            early_stop_callback=early_stopping,
            val_check_interval=0.333,  # validate roughly three times per epoch
            gradient_clip_val=8.0,
            resume_from_checkpoint=resume,
            accumulate_grad_batches=4
        )
        trainer.fit(model)
        
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run training/evaluation.
    # (The former unused `net = Net()` instance has been dropped.)
    root_dir = os.path.dirname(os.path.realpath(__file__))
    parser = get_argument_parser()

    # each LightningModule defines arguments relevant to it
    parser = Net.add_model_specific_args(parser, root_dir)
    hyperparams = parser.parse_args()
    main(hyperparams, "")
    
